Load the necessary libraries
library(gbm) #for gradient boosted models
library(car)
library(dismo)
library(pdp)
library(ggfortify)
library(randomForest)
library(tidyverse)
library(gridExtra)
library(patchwork)
Verneaux (1973) measured the abundance of 27 fish species and 11 environmental variables from 30 locations along the Doubs river which runs along the France-Switzerland border. The environmental data comprised a range of hydrology, geomorphology and chemistry parameters and amongst other things, the ichthyologist was interested in relating fish abundances to the environmental drivers.
The data are in two files.
verneaux.fish.csv - the abundance of 27 fish species
verneaux.env.csv - the environmental data

| Variable | Description |
|---|---|
| DAS | Distance from source (km) |
| ALT | Altitude (m above sea level) |
| PEN | Slope (per thousand) |
| DEB | Mean minimum discharge (m3.s-1) |
| PH | Water pH |
| DUR | Calcium concentration (hardness) (mgL^-1) |
| PHO | Phosphorus concentration (mgL^-1) |
| NIT | Nitrate concentration (mgL^-1) |
| AMM | Ammonium concentration (mgL^-1) |
| OXY | Dissolved oxygen (mgL^-1) |
| DBO | Biological oxygen demand (mgL^-1) |
# Read in the fish abundance data: one row per site (location along the
# river), one column per species abundance score.
fish <- read_csv('../public/data/verneaux.fish.csv', trim_ws = TRUE)
glimpse(fish)
## Rows: 30
## Columns: 27
## $ CHA <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 3, 3, 2, 1, 1, 0, 0, 0, 0, …
## $ TRU <dbl> 3, 5, 5, 4, 2, 3, 5, 0, 0, 1, 3, 5, 5, 5, 4, 3, 2, 1, 0, 0, 0, 0, …
## $ VAI <dbl> 0, 4, 5, 5, 3, 4, 4, 0, 1, 4, 4, 4, 5, 5, 4, 3, 4, 3, 3, 1, 1, 0, …
## $ LOC <dbl> 0, 3, 5, 5, 2, 5, 5, 0, 3, 4, 1, 4, 2, 4, 5, 5, 4, 3, 5, 2, 1, 1, …
## $ OMB <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 2, 0, 1, 1, 0, 0, 0, 0, …
## $ BLA <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 4, 5, 2, 1, 1, 0, 0, 0, …
## $ HOT <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 3, …
## $ TOX <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 3, 3, 2, 2, 2, …
## $ VAN <dbl> 0, 0, 0, 0, 5, 1, 1, 0, 0, 2, 0, 0, 0, 0, 3, 5, 3, 2, 2, 2, 2, 3, …
## $ CHE <dbl> 0, 0, 0, 1, 2, 2, 1, 0, 5, 2, 1, 1, 0, 1, 3, 2, 2, 3, 1, 3, 2, 4, …
## $ BAR <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 3, 3, 2, 4, 4, 5, …
## $ SPI <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 4, 3, 2, 3, 2, 1, …
## $ GOU <dbl> 0, 0, 0, 1, 2, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 1, 2, 4, 4, 5, 5, …
## $ BRO <dbl> 0, 0, 1, 2, 4, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 2, 3, 3, …
## $ PER <dbl> 0, 0, 0, 2, 4, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 1, 2, 3, 4, …
## $ BOU <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 3, 3, 3, …
## $ PSO <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, …
## $ ROT <dbl> 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, …
## $ CAR <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 3, …
## $ TAN <dbl> 0, 0, 0, 1, 3, 2, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 4, 4, 4, …
## $ BCO <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 3, 4, …
## $ PCH <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, …
## $ GRE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3, 4, …
## $ GAR <dbl> 0, 0, 0, 0, 5, 1, 0, 0, 4, 0, 0, 0, 0, 0, 0, 1, 2, 2, 5, 5, 5, 5, …
## $ BBO <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, …
## $ ABL <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 3, 5, 5, 5, …
## $ ANG <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, …
# Read in the environmental data: one row per site, one column per
# environmental variable (see the table above for variable meanings).
env <- read_csv('../public/data/verneaux.env.csv', trim_ws = TRUE)
glimpse(env)
## Rows: 30
## Columns: 11
## $ DAS <dbl> 0.3, 2.2, 10.2, 18.5, 21.5, 32.4, 36.8, 49.1, 70.5, 99.0, 123.4, 1…
## $ ALT <dbl> 934, 932, 914, 854, 849, 846, 841, 792, 752, 617, 483, 477, 450, 4…
## $ PEN <dbl> 48.0, 3.0, 3.7, 3.2, 2.3, 3.2, 6.6, 2.5, 1.2, 9.9, 4.1, 1.6, 2.1, …
## $ DEB <dbl> 0.84, 1.00, 1.80, 2.53, 2.64, 2.86, 4.00, 1.30, 4.80, 10.00, 19.90…
## $ PH <dbl> 7.9, 8.0, 8.3, 8.0, 8.1, 7.9, 8.1, 8.1, 8.0, 7.7, 8.1, 7.9, 8.1, 8…
## $ DUR <dbl> 45, 40, 52, 72, 84, 60, 88, 94, 90, 82, 96, 86, 98, 98, 86, 88, 92…
## $ PHO <dbl> 0.01, 0.02, 0.05, 0.10, 0.38, 0.20, 0.07, 0.20, 0.30, 0.06, 0.30, …
## $ NIT <dbl> 0.20, 0.20, 0.22, 0.21, 0.52, 0.15, 0.15, 0.41, 0.82, 0.75, 1.60, …
## $ AMM <dbl> 0.00, 0.10, 0.05, 0.00, 0.20, 0.00, 0.00, 0.12, 0.12, 0.01, 0.00, …
## $ OXY <dbl> 12.2, 10.3, 10.5, 11.0, 8.0, 10.2, 11.1, 7.0, 7.2, 10.0, 11.5, 12.…
## $ DBO <dbl> 2.7, 1.9, 3.5, 1.3, 6.2, 5.3, 2.2, 8.1, 5.2, 4.3, 2.7, 3.0, 2.4, 3…
For this example, we are going to focus on total fish abundance and therefore we need to sum the total number of fish within each location (rows).
# Total fish abundance per site: sum each row across the 27 species
# columns. rowSums() returns a plain numeric vector (one value per
# location); calling it directly is clearer than piping into the bare
# function name.
fish <- rowSums(fish)
# Quick visual screen of the pairwise relationships between total
# abundance and each environmental predictor.
car::scatterplotMatrix(cbind(fish, env))
Now we will specify the gradient boosted model with:
The values of n.trees, interaction.depth
and shrinkage used below are purely based on what are
typically good starting points. Nevertheless, they will be data set
dependent. We will start off with those values and then evaluate whether
they are appropriate.
# Initial gradient boosted model of total fish abundance.
# - distribution = 'poisson' because the response is a count (sum of
#   abundance scores per site)
# - var.monotone constrains the direction of each partial effect
#   (1 = increasing, -1 = decreasing, 0 = unconstrained), in formula
#   order: DAS, ALT, PEN, DEB, PH, DUR, PHO, NIT, AMM, OXY, DBO
# - n.minobsinnode = 2 because there are only 30 observations
# - cv.folds = 3 provides a cross-validated estimate of the optimal
#   number of trees (evaluated below with gbm.perf)
fish.gbm <- gbm(fish ~ DAS + ALT + PEN + DEB + PH + DUR + PHO + NIT + AMM +
                  OXY + DBO,
                data = env,
                var.monotone = c(1, -1, -1, 1, 0, 1, 0, 0, 0, 0, 0),
                distribution = 'poisson',
                n.trees = 10000,
                n.minobsinnode = 2,
                interaction.depth = 5,
                bag.fraction = 0.5,
                shrinkage = 0.01,
                train.fraction = 1,
                cv.folds = 3)
We will now determine the optimum number of trees estimated to be required in order to achieve a balance between bias (biased towards the exact observations) and precision (variability in estimates). Ideally, the optimum number of trees should be close to 1000. If it is much less (as in this case), it could imply that the tree learned too quickly. On the other hand, if the optimum number of trees is very close to the total number of fitted trees, then it suggests that the optimum may not actually have occurred yet and that more trees should be used (or a faster learning rate).
(best.iter = gbm.perf(fish.gbm,method='OOB'))
## [1] 262
## attr(,"smoother")
## Call:
## loess(formula = object$oobag.improve ~ x, enp.target = min(max(4,
## length(x)/10), 50))
##
## Number of Observations: 10000
## Equivalent Number of Parameters: 39.99
## Residual Standard Error: 0.003112
(best.iter = gbm.perf(fish.gbm,method='cv'))
## [1] 525
Conclusions:
# Refit with a ten-fold slower learning rate (shrinkage 0.01 -> 0.001) so
# that the optimal number of trees moves closer to the recommended ~1000+
# region; all other settings are unchanged.
fish.gbm <- gbm(fish ~ DAS + ALT + PEN + DEB + PH + DUR + PHO + NIT + AMM +
                  OXY + DBO,
                data = env,
                var.monotone = c(1, -1, -1, 1, 0, 1, 0, 0, 0, 0, 0),
                distribution = 'poisson',
                n.trees = 10000,
                n.minobsinnode = 2,
                interaction.depth = 5,
                bag.fraction = 0.5,
                shrinkage = 0.001,
                train.fraction = 1,
                cv.folds = 3)
(best.iter = gbm.perf(fish.gbm,method='OOB'))
## [1] 2589
## attr(,"smoother")
## Call:
## loess(formula = object$oobag.improve ~ x, enp.target = min(max(4,
## length(x)/10), 50))
##
## Number of Observations: 10000
## Equivalent Number of Parameters: 39.99
## Residual Standard Error: 0.0008918
(best.iter = gbm.perf(fish.gbm,method='cv'))
## [1] 3008
Conclusions:
If a predictor is an important driver of the patterns in the response, then many of the tree splits should feature this predictor. It thus follows that the proportion of total splits that features each predictor will be a measure of the relative influence of each of the predictors.
summary(fish.gbm, n.trees=best.iter)
Conclusions:
attr(fish.gbm$Terms,"term.labels")
## [1] "DAS" "ALT" "PEN" "DEB" "PH" "DUR" "PHO" "NIT" "AMM" "OXY" "DBO"
# Quick partial dependence plots for the first two predictors by index
# (1 = DAS, 2 = ALT), at the optimal number of trees.
plot(fish.gbm, 1, n.tree=best.iter)
plot(fish.gbm, 2, n.tree=best.iter)
# Partial dependence of total abundance on distance from source (DAS).
# recursive=FALSE forces the slower brute-force method, which is required
# for inv.link; inv.link=exp back-transforms from the Poisson log link to
# the response (abundance) scale.
fish.gbm %>%
pdp::partial(pred.var='DAS',
n.trees=best.iter,
recursive=FALSE,
inv.link=exp) %>%
autoplot()
Recursive indicates that a weighted tree traversal method described by Friedman 2001 (which is very fast) should be used (only works for gbm). Otherwise a slower brute force method is used. If want to back transform - need to use brute force.
# Build a partial dependence plot for every predictor, back-transformed to
# the response scale and drawn on a common y axis so panels are comparable.
# The name of each predictor is printed as a progress indicator.
nms <- attr(fish.gbm$Terms, "term.labels")
p <- setNames(
  lapply(nms, function(nm) {
    print(nm)
    fish.gbm %>%
      pdp::partial(pred.var = nm,
                   n.trees = best.iter,
                   inv.link = exp,
                   recursive = FALSE,
                   type = 'regression') %>%
      autoplot() +
      ylim(18, 60)
  }),
  nms)
## [1] "DAS"
## [1] "ALT"
## [1] "PEN"
## [1] "DEB"
## [1] "PH"
## [1] "DUR"
## [1] "PHO"
## [1] "NIT"
## [1] "AMM"
## [1] "OXY"
## [1] "DBO"
# Arrange all per-predictor partial dependence plots in a grid.
patchwork::wrap_plots(p)
#do.call('grid.arrange', p)
We might also want to explore interactions…
# Single-predictor partial dependence for DAS (response scale).
fish.gbm %>%
pdp::partial(pred.var=c('DAS'),
n.trees=best.iter, recursive=FALSE, inv.link=exp) %>%
autoplot()
# Two-way partial dependence of DAS and ALT. recursive=TRUE uses the fast
# weighted tree traversal, so inv.link cannot be applied (link scale).
fish.gbm %>%
pdp::partial(pred.var=c('DAS','ALT'),
n.trees=best.iter, recursive=TRUE) %>%
autoplot()
# Side-by-side (patchwork) partial dependence plots for DAS and ALT,
# both back-transformed to the response scale.
g1 <- fish.gbm %>% pdp::partial(pred.var = 'DAS', n.trees = best.iter,
                                recursive = FALSE, inv.link = exp) %>%
  autoplot()
g2 <- fish.gbm %>% pdp::partial(pred.var = 'ALT', n.trees = best.iter,
                                recursive = FALSE, inv.link = exp) %>%
  autoplot()
g1 + g2
# Assess fit: predictions at the optimal number of trees, on the response
# scale, alongside the observed totals. Binding `fish` in as a column
# means the correlation (and later plots) read it from fish.acc rather
# than silently falling back to the global vector of the same name.
fish.acc <- env %>%
  bind_cols(fish = fish,
            Pred = predict(fish.gbm,
                           newdata = env,
                           n.tree = best.iter,
                           type = 'response'))
with(fish.acc, cor(fish, Pred))
## [1] 0.9905805
# Observed total abundance vs gbm predictions.
fish.acc %>%
ggplot() +
geom_point(aes(y=Pred, x=fish))
Computes Friedman’s H-statistic to assess the strength of variable interactions. This measures the relative strength of interactions in models. It is on a scale of 0–1, where 1 indicates a very strong interaction. For a purely additive model, y = β₀ + β₁x₁ + β₂x₂ + β₃x₃ + …, there are no interactions and H = 0. If both main effects are weak, then the H-statistic will be unstable and could spuriously indicate a strong interaction.
What were the strong main effects: - DAS - ALT - OXY
# Predictor names again, to map column indices to variables:
# 1 = DAS, 2 = ALT, 10 = OXY (the strong main effects).
attr(fish.gbm$Terms,"term.labels")
## [1] "DAS" "ALT" "PEN" "DEB" "PH" "DUR" "PHO" "NIT" "AMM" "OXY" "DBO"
interact.gbm(fish.gbm, env,c(1,2), n.tree=best.iter)  # DAS x ALT
## [1] 0.03754515
interact.gbm(fish.gbm, env,c(1,10), n.tree=best.iter)  # DAS x OXY (largest H)
## [1] 0.1102318
interact.gbm(fish.gbm, env,c(2,10), n.tree=best.iter)  # ALT x OXY
## [1] 0.05219991
interact.gbm(fish.gbm, env,c(1,2,10), n.tree=best.iter)  # three-way DAS x ALT x OXY
## [1] 0.003372834
# Two-way partial dependence of DAS and OXY (brute force, link scale).
fish.gbm %>% pdp::partial(pred.var=c('DAS', 'OXY'), n.trees=best.iter, recursive=FALSE) %>%
autoplot
# As above, but back-transformed to the response (abundance) scale.
fish.gbm %>% pdp::partial(pred.var=c('DAS', 'OXY'),
n.trees=best.iter,
recursive=FALSE,
inv.link = exp) %>%
autoplot
# Predictors can also be supplied by column index (1=DAS, 2=ALT, 10=OXY).
fish.gbm %>% pdp::partial(pred.var=c(1, 10), n.trees=best.iter, recursive=FALSE) %>% autoplot
fish.gbm %>% pdp::partial(pred.var=c(2, 10), n.trees=best.iter, recursive=FALSE) %>% autoplot
# Evaluate the DAS/OXY surface on a grid (return.grid=TRUE returns the
# data rather than plotting) and draw it as filled tiles plus contours.
fish.grid = plot(fish.gbm, c(1,10), n.tree=best.iter, return.grid=TRUE)
head(fish.grid)
ggplot(fish.grid, aes(y=DAS, x=OXY)) +
geom_tile(aes(fill=y)) +
geom_contour(aes(z=y)) +
scale_fill_gradientn(colors=heat.colors(10))
## [1] "i= 1 Name = DAS"
## [1] "j= 2 Name = ALT"
## [1] "i= 1 Name = DAS"
## [1] "j= 3 Name = PEN"
## [1] "i= 1 Name = DAS"
## [1] "j= 4 Name = DEB"
## [1] "i= 1 Name = DAS"
## [1] "j= 5 Name = PH"
## [1] "i= 1 Name = DAS"
## [1] "j= 6 Name = DUR"
## [1] "i= 1 Name = DAS"
## [1] "j= 7 Name = PHO"
## [1] "i= 1 Name = DAS"
## [1] "j= 8 Name = NIT"
## [1] "i= 1 Name = DAS"
## [1] "j= 9 Name = AMM"
## [1] "i= 1 Name = DAS"
## [1] "j= 10 Name = OXY"
## [1] "i= 1 Name = DAS"
## [1] "j= 11 Name = DBO"
## [1] "i= 2 Name = ALT"
## [1] "j= 3 Name = PEN"
## [1] "i= 2 Name = ALT"
## [1] "j= 4 Name = DEB"
## [1] "i= 2 Name = ALT"
## [1] "j= 5 Name = PH"
## [1] "i= 2 Name = ALT"
## [1] "j= 6 Name = DUR"
## [1] "i= 2 Name = ALT"
## [1] "j= 7 Name = PHO"
## [1] "i= 2 Name = ALT"
## [1] "j= 8 Name = NIT"
## [1] "i= 2 Name = ALT"
## [1] "j= 9 Name = AMM"
## [1] "i= 2 Name = ALT"
## [1] "j= 10 Name = OXY"
## [1] "i= 2 Name = ALT"
## [1] "j= 11 Name = DBO"
## [1] "i= 3 Name = PEN"
## [1] "j= 4 Name = DEB"
## [1] "i= 3 Name = PEN"
## [1] "j= 5 Name = PH"
## [1] "i= 3 Name = PEN"
## [1] "j= 6 Name = DUR"
## [1] "i= 3 Name = PEN"
## [1] "j= 7 Name = PHO"
## [1] "i= 3 Name = PEN"
## [1] "j= 8 Name = NIT"
## [1] "i= 3 Name = PEN"
## [1] "j= 9 Name = AMM"
## [1] "i= 3 Name = PEN"
## [1] "j= 10 Name = OXY"
## [1] "i= 3 Name = PEN"
## [1] "j= 11 Name = DBO"
## [1] "i= 4 Name = DEB"
## [1] "j= 5 Name = PH"
## [1] "i= 4 Name = DEB"
## [1] "j= 6 Name = DUR"
## [1] "i= 4 Name = DEB"
## [1] "j= 7 Name = PHO"
## [1] "i= 4 Name = DEB"
## [1] "j= 8 Name = NIT"
## [1] "i= 4 Name = DEB"
## [1] "j= 9 Name = AMM"
## [1] "i= 4 Name = DEB"
## [1] "j= 10 Name = OXY"
## [1] "i= 4 Name = DEB"
## [1] "j= 11 Name = DBO"
## [1] "i= 5 Name = PH"
## [1] "j= 6 Name = DUR"
## [1] "i= 5 Name = PH"
## [1] "j= 7 Name = PHO"
## [1] "i= 5 Name = PH"
## [1] "j= 8 Name = NIT"
## [1] "i= 5 Name = PH"
## [1] "j= 9 Name = AMM"
## [1] "i= 5 Name = PH"
## [1] "j= 10 Name = OXY"
## [1] "i= 5 Name = PH"
## [1] "j= 11 Name = DBO"
## [1] "i= 6 Name = DUR"
## [1] "j= 7 Name = PHO"
## [1] "i= 6 Name = DUR"
## [1] "j= 8 Name = NIT"
## [1] "i= 6 Name = DUR"
## [1] "j= 9 Name = AMM"
## [1] "i= 6 Name = DUR"
## [1] "j= 10 Name = OXY"
## [1] "i= 6 Name = DUR"
## [1] "j= 11 Name = DBO"
## [1] "i= 7 Name = PHO"
## [1] "j= 8 Name = NIT"
## [1] "i= 7 Name = PHO"
## [1] "j= 9 Name = AMM"
## [1] "i= 7 Name = PHO"
## [1] "j= 10 Name = OXY"
## [1] "i= 7 Name = PHO"
## [1] "j= 11 Name = DBO"
## [1] "i= 8 Name = NIT"
## [1] "j= 9 Name = AMM"
## [1] "i= 8 Name = NIT"
## [1] "j= 10 Name = OXY"
## [1] "i= 8 Name = NIT"
## [1] "j= 11 Name = DBO"
## [1] "i= 9 Name = AMM"
## [1] "j= 10 Name = OXY"
## [1] "i= 9 Name = AMM"
## [1] "j= 11 Name = DBO"
## [1] "i= 10 Name = OXY"
## [1] "j= 11 Name = DBO"
This takes a long time - do it over a break.
# dismo::gbm.step adds trees in stages and uses cross-validation to find
# the optimal total automatically.
# cbind(fish, env) places the response in column 1 and the 11 predictors
# in columns 2:12 -- the original gbm.x = 2:11 silently dropped DBO.
# NOTE(review): n.train is not a documented gbm.step/gbm argument (gbm's
# equivalent is train.fraction) -- confirm it is intended here.
fish.gbm1 <- dismo::gbm.step(data = cbind(fish, env) %>% as.data.frame,
                             gbm.x = 2:12, gbm.y = 1,
                             tree.complexity = 5,
                             learning.rate = 0.001,
                             n.minobsinnode = 2,
                             bag.fraction = 0.5,
                             n.train = 1,
                             n.trees = 10000,
                             family = 'poisson')
# The original called summary(abalone.gbm1), an object from a different
# example; fish.gbm1 is the model fitted here.
summary(fish.gbm1)
# (randomForest was already attached at the top of the script; re-loading
# is harmless.)
library(randomForest)
# Random forest of total fish abundance against the same 11 predictors,
# for comparison with the boosted model. importance=TRUE stores the
# permutation-based importance measures used below.
fish.rf <- randomForest(fish ~ DAS + ALT + PEN + DEB + PH + DUR + PHO + NIT + AMM +
                          OXY + DBO,
                        data = env, importance = TRUE,
                        ntree = 1000)
fish.imp <- randomForest::importance(fish.rf)
## Rank by either:
## *MSE (mean decrease in accuracy)
##   For each tree, calculate OOB prediction error.
##   This is also done after permuting predictors.
##   Then average the difference of prediction errors for each tree.
## *NodePurity (mean decrease in node impurity)
##   Measure of the total decline of impurity due to each
##   predictor averaged over trees.
# Express each importance measure as a percentage of its own column total.
# (The original divided by sum(fish.imp) -- the grand total over BOTH
# columns -- so the two measures were mixed on incomparable scales and
# neither column summed to 100, as the printed output below shows.)
100 * prop.table(fish.imp, margin = 2)
## %IncMSE IncNodePurity
## DAS 0.10514608 19.578297
## ALT 0.08822959 15.527627
## PEN 0.03681748 7.102072
## DEB 0.08942111 18.600003
## PH -0.01288507 1.065126
## DUR 0.03792810 8.052133
## PHO 0.05512223 5.574289
## NIT 0.06791615 8.942640
## AMM 0.05865963 5.675618
## OXY 0.04979067 4.982862
## DBO 0.05360180 4.269585
# Dot-chart of both importance measures.
varImpPlot(fish.rf)
## use brute force
# Partial dependence of DAS from the random forest; pdp falls back to the
# brute-force method here (no gbm-style recursive shortcut).
fish.rf %>% pdp::partial('DAS') %>% autoplot
# Random forest fit accuracy, with the observed totals carried as a
# column so the correlation does not rely on the global `fish` vector.
fish.rf.acc <- env %>%
  bind_cols(fish = fish,
            Pred = predict(fish.rf,
                           newdata = env))
with(fish.rf.acc, cor(fish, Pred))
## [1] 0.9843458
# Observed vs predicted for both models: black = random forest,
# red = gbm (fish.acc, computed above).
fish.rf.acc %>%
  ggplot() +
  geom_point(aes(y = Pred, x = fish)) +
  geom_point(data = fish.acc, aes(y = Pred, x = fish), colour = 'red')